import random
import os
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.utils.data
from torch import nn
import numpy as np
import time
import argparse
from data_loader import GetLoader
from torchvision import datasets
from torchvision import transforms
from model import CNNModel


# Command-line hyper-parameters.
#   --lr_x          learning rate handed to the SGD optimizer of the network
#   --lr_y          step size of the manual ascent update on my_logi
#   --lam           coefficient of the squared-norm penalty on the domain weights
#   --feature_size  width of the feature vector fed to the linear domain heads
#   --epoch         number of iterations of the training loop
parser = argparse.ArgumentParser()
for _flag, _type, _default in (
    ('--lr_x', float, 1e-2),
    ('--lr_y', float, 1e-2),
    ('--lam', float, 1),
    ('--feature_size', int, 100),
    ('--epoch', int, 20000),
):
    parser.add_argument(_flag, type=_type, default=_default)

args = parser.parse_args()

# Unpack the parsed namespace into the module-level names used below.
lr_x, lr_y = args.lr_x, args.lr_y
lam = args.lam
feature_size = args.feature_size
n_epoch = args.epoch

# Dataset names and the directories derived from them.
source_dataset_name, target_dataset_name = 'MNIST', 'mnist_m'
source_image_root = os.path.join('dataset', source_dataset_name)
target_image_root = os.path.join('dataset', target_dataset_name)

# Run on the GPU whenever one is visible to PyTorch.
cuda = torch.cuda.is_available()


image_size = 28  # side length passed to transforms.Resize
step = 1000      # diagnostics are recorded every `step` iterations

# Fixed seed for Python's and PyTorch's random number generators.
manual_seed = 8
random.seed(manual_seed)
torch.manual_seed(manual_seed)

# load data
#
# Both domains are resized to image_size and converted to tensors; they differ
# only in the normalisation constants (one channel for the source, three for
# the target).

img_transform_source = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

img_transform_target = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Source domain: torchvision MNIST training split (downloaded on first use).
dataset_source = datasets.MNIST(
    'dataset',
    train=True,
    transform=img_transform_source,
    download=True,
)

# batch_size equals the dataset size, so one batch holds the whole dataset.
dataloader_source = torch.utils.data.DataLoader(
    dataset_source,
    batch_size=len(dataset_source),
    num_workers=8,
)

# Target domain: images listed in the label file, read through GetLoader.
train_list = os.path.join(target_image_root, 'mnist_m_train_labels.txt')

dataset_target = GetLoader(
    data_root=os.path.join(target_image_root, 'mnist_m_train'),
    data_list=train_list,
    transform=img_transform_target,
)

dataloader_target = torch.utils.data.DataLoader(
    dataset_target,
    batch_size=len(dataset_target),
    shuffle=True,
    num_workers=8,
)

# Iterators from which the single full-size batch of each domain is pulled.
data_source_iter = iter(dataloader_source)
data_target_iter = iter(dataloader_target)

# load model
#
# my_net      : feature extractor + classifier (returns features and class output)
# my_logi     : linear domain head updated by the manual ascent step in the
#               training loop (it has no optimizer of its own)
# my_logi_phi : scratch copy of my_logi that is re-optimised at every checkpoint
# All three are converted to double precision.

my_net = CNNModel(feature_size).double()
my_logi = torch.nn.Linear(feature_size, 1, bias=False).double()
my_logi_phi = torch.nn.Linear(feature_size, 1, bias=False).double()

loss_class = torch.nn.NLLLoss()
loss_domain = torch.nn.BCEWithLogitsLoss()

# Move everything to the GPU *before* the optimizers are constructed, as the
# torch.optim documentation recommends, so the optimizers are built from the
# parameters in their final location.
if cuda:
    my_net = my_net.cuda()
    my_logi = my_logi.cuda()
    my_logi_phi = my_logi_phi.cuda()
    loss_class = loss_class.cuda()
    loss_domain = loss_domain.cuda()

# Make sure every parameter participates in autograd.
for module in (my_net, my_logi, my_logi_phi):
    for p in module.parameters():
        p.requires_grad = True

# setup optimizer
#
# Plain SGD for the network; Nesterov-momentum SGD for the scratch head
# (lr=0.02 was tried as an alternative for the latter).

net_optimizer = optim.SGD(my_net.parameters(), lr=lr_x)
logi_optimizer = optim.SGD(my_logi_phi.parameters(), lr=0.01, momentum=0.99, nesterov=True)

print('Load model done!')


# load source data
# Pull the single batch produced by the source loader and flatten every image
# to a 28*28 vector in double precision.
print('Loading source data...')
# Built-in next() instead of the iterator's .next() method, which Python 3
# iterators (and recent DataLoader iterators) do not provide.
data_source = next(data_source_iter)
s_img, s_label = data_source
s_img = s_img.view(-1, 28 * 28)
s_img = s_img.double()
print('Finish loading source data.')

# load target data
print('Loading target data...')
data_target = next(data_target_iter)
t_img, t_label = data_target
# Collapse the three colour channels into one with the usual luma weights
# (0.299 R + 0.587 G + 0.114 B). The blue weight must be applied to channel 2,
# not to the green channel a second time.
t_img = 0.299 * t_img[:, 0, :, :] + 0.587 * t_img[:, 1, :, :] + 0.114 * t_img[:, 2, :, :]
t_img = t_img.view(-1, 28 * 28)
t_img = t_img.double()
print('Finish loading target data.')

# Domain labels: 0 for every source sample, 1 for every target sample.
# Created as double so they match the double-precision logits they are
# compared against, without relying on implicit dtype promotion in the loss.
s_domain_label = torch.zeros(len(s_label), dtype=torch.double)
t_domain_label = torch.ones(len(t_img), dtype=torch.double)

if cuda:
    s_img = s_img.cuda()
    t_img = t_img.cuda()
    s_label = s_label.cuda()
    t_label = t_label.cuda()
    s_domain_label = s_domain_label.cuda()
    t_domain_label = t_domain_label.cuda()

# Histories filled in at every checkpoint of the training loop: value of the
# P function, gradient norms w.r.t. the network (x) and the domain head (y),
# and accumulated training time. accu_result is allocated alongside them.
P_result, x_grad_result, y_grad_result = [], [], []
time_result, accu_result = [], []

# Accumulated training time in seconds (diagnostics are excluded by the loop).
total_time = 0



# training
#
# Gradient descent-ascent on
#     f(x, y) = class_loss(x) - domain_loss(x, y) - lam / 2 * ||y||^2
# where x are the parameters of my_net and y is the weight of my_logi:
# net_optimizer takes a descent step in x, and y is moved along +grad by lr_y.
# Every `step` iterations the gradient norms are recorded and
# P(x) = max_y f(x, y) is estimated by re-optimising a copy of y (my_logi_phi).
# Only the two update steps and the forward/backward pass are timed; the
# diagnostics are excluded by stopping and restarting the clock around them.

start_time = time.time()

for epoch in range(n_epoch):
    is_checkpoint = (epoch + 1) % step == 0

    my_net.zero_grad()
    my_logi.zero_grad()

    # Classification loss on the labelled source data.
    s_fea, class_output = my_net(input_data=s_img)
    err_s_label = loss_class(class_output, s_label)

    # Domain loss on source (label 0) and target (label 1) features.
    s_domain_output = my_logi(s_fea)
    err_s_domain = loss_domain(s_domain_output.squeeze(), s_domain_label)

    t_fea, _ = my_net(input_data=t_img)
    t_domain_output = my_logi(t_fea)
    err_t_domain = loss_domain(t_domain_output.squeeze(), t_domain_label)

    err = err_s_label - err_t_domain - err_s_domain - lam / 2 * (my_logi.weight.norm() ** 2)

    err.backward()

    if is_checkpoint:
        # Stop the clock while the gradient norms are collected.
        end_time = time.time()
        total_time += end_time - start_time

        x_grad = torch.cat([para.grad.data.view(-1) for para in my_net.parameters()], 0).norm().cpu().numpy()
        y_grad = my_logi.weight.grad.norm().detach().cpu().numpy()
        x_grad_result.append(x_grad)
        y_grad_result.append(y_grad)

        start_time = time.time()

    # Descent step on the network parameters.
    net_optimizer.step()

    # Ascent step on the domain head (hence the plus sign).
    with torch.no_grad():
        my_logi.weight += my_logi.weight.grad * lr_y

    if is_checkpoint:

        end_time = time.time()
        total_time += end_time - start_time
        time_result.append(total_time)

        ## Perform AGD
        # Start the inner solver from the current domain head.
        my_logi_phi.load_state_dict(my_logi.state_dict())

        # The network is fixed during the inner problem, so its features are
        # computed once without building an autograd graph.
        with torch.no_grad():

            s_fea_phi, _ = my_net(input_data=s_img)
            t_fea_phi, _ = my_net(input_data=t_img)

        # NOTE(review): logi_optimizer keeps its momentum buffers from the
        # previous checkpoint; confirm whether they should be reset here.
        for _ in range(1000):
            # logi_optimizer owns exactly my_logi_phi's parameters, so one
            # zero_grad call is sufficient.
            logi_optimizer.zero_grad()

            s_domain_output_phi = my_logi_phi(s_fea_phi)
            t_domain_output_phi = my_logi_phi(t_fea_phi)

            err_s_domain_phi = loss_domain(s_domain_output_phi.squeeze(), s_domain_label)
            err_t_domain_phi = loss_domain(t_domain_output_phi.squeeze(), t_domain_label)

            # The penalty must be on the variable being optimised
            # (my_logi_phi). Penalising my_logi.weight here would give
            # my_logi_phi no regularisation gradient and would leak gradients
            # into my_logi instead.
            err_domain = err_s_domain_phi + err_t_domain_phi + lam / 2 * (my_logi_phi.weight.norm() ** 2)

            err_domain.backward()
            logi_optimizer.step()

        ## Compute P function
        # f evaluated at the current network and the re-optimised head; the
        # penalty uses the same head as the domain losses.
        with torch.no_grad():

            s_fea, class_output = my_net(input_data=s_img)
            err_s_label = loss_class(class_output, s_label)

            s_domain_output = my_logi_phi(s_fea)
            err_s_domain = loss_domain(s_domain_output.squeeze(), s_domain_label)

            t_fea, _ = my_net(input_data=t_img)
            t_domain_output = my_logi_phi(t_fea)
            err_t_domain = loss_domain(t_domain_output.squeeze(), t_domain_label)

            err = err_s_label - err_t_domain - err_s_domain - lam / 2 * (my_logi_phi.weight.norm() ** 2)
            err = err.detach().cpu().numpy()

        P_result.append(err)

        print('Epoch =', epoch, 'agd grad = ', my_logi_phi.weight.grad.norm().cpu().numpy(), 'x_grad =', x_grad, 'y_grad =', y_grad, 'P =', err, 'Time =', total_time)

        # Restart the clock now that the diagnostics are done.
        start_time = time.time()

# Store the recorded histories in an .npz archive whose name encodes the
# hyper-parameters: <lam>_<feature_size>_<lr_x>_<lr_y>_gda_mnist_mnist_m.npz
filename = '_'.join(str(value) for value in (lam, feature_size, lr_x, lr_y)) + '_gda_mnist_mnist_m.npz'
np.savez(
    filename,
    P=P_result,
    Time=time_result,
    xgrad=x_grad_result,
    ygrad=y_grad_result,
)

